001    package com.mockrunner.base;
002    
003    import java.lang.reflect.Constructor;
004    import java.util.ArrayList;
005    import java.util.Enumeration;
006    import java.util.List;
007    
008    import junit.framework.Test;
009    import junit.framework.TestCase;
010    import junit.framework.TestResult;
011    import junit.framework.TestSuite;
012    
013    import org.apache.commons.logging.Log;
014    import org.apache.commons.logging.LogFactory;
015    
016    /**
017     * This TestSuite starts all test methods in
018     * a number of seperate threads. <b>Doesn't
019     * work properly yet. Do not use it :-)</b>
020     */
021    public class MultiThreadTestSuite extends TestSuite
022    {
023            private final static Log log = LogFactory.getLog(MultiThreadTestSuite.class);
024        private int numberThreads;
025            private boolean doClone;
026            
027            public MultiThreadTestSuite(Class theClass, String name)
028            {
029                    this(theClass, name, 5, true);
030            }
031            
032            public MultiThreadTestSuite(Class theClass)
033            {
034                    this(theClass, 5, true);
035            }
036            
037            public MultiThreadTestSuite(Class theClass, String name, int numberThreads, boolean doClone)
038            {
039                    super(theClass, name);
040                    this.numberThreads = numberThreads;
041                    this.doClone = doClone;
042            }
043    
044            public MultiThreadTestSuite(Class theClass, int numberThreads, boolean doClone)
045            {
046                    super(theClass);
047                    this.numberThreads = numberThreads;
048                    this.doClone = doClone;
049            }
050            
051            public void run(TestResult result)
052            {
053                    Enumeration tests = tests();
054                    while(tests.hasMoreElements()) 
055                    {
056                            if(result.shouldStop()) return;
057                            TestCase currentTest = (TestCase)tests.nextElement();   
058                            List threads = createThreadListForTest(currentTest, result);
059                            runAllThreadsForTest(threads);
060                    }
061            }
062            
063            private Test createNewTestInstanceBasedOn(TestCase test)
064            {
065                    TestCase newTest = null;
066                    try
067                    {       
068                            Constructor constructor = getTestConstructor(test.getClass());
069                            if (constructor.getParameterTypes().length == 0) 
070                            {
071                                    newTest = (TestCase)constructor.newInstance(new Object[0]);
072                                    newTest.setName(test.getName());
073                                                                    
074                            } 
075                            else 
076                            {
077                                    newTest = (TestCase)constructor.newInstance(new Object[]{test.getName()});
078                            }
079                    }
080                    catch(Exception exc)
081                    {
082                            log.error(exc.getMessage(), exc);
083                    }
084                    return newTest;
085            }
086    
087            private List createThreadListForTest(TestCase currentTest, TestResult result)
088            {
089                    ArrayList threads = new ArrayList(numberThreads);
090                    for(int ii = 0; ii < numberThreads; ii++)
091                    {
092                            Test newTest = currentTest;
093                            if(doClone) newTest = createNewTestInstanceBasedOn(currentTest);
094                            TestThread thread = new TestThread("TestThread " + ii, newTest, result);
095                            threads.add(thread);
096                    }
097                    return threads;
098            }
099    
100            private void runAllThreadsForTest(List threads)
101            {
102                    for(int ii = 0; ii < threads.size(); ii++)
103                    {
104                            Thread thread = (Thread)threads.get(ii);
105                            thread.start();
106                    }
107                    for(int ii = 0; ii < threads.size(); ii++)
108                    {
109                            Thread thread = (Thread)threads.get(ii);
110                            try
111                            {
112                                    thread.join();
113                            }
114                            catch(InterruptedException exc)
115                            {
116                                    log.error("Interrupted", exc);
117                            }
118                    }
119            }
120            
121            private static class TestThread extends Thread
122            {
123                    private Test test;
124                    private TestResult result;
125                    
126                    public TestThread(String name, Test test, TestResult result)
127                    {
128                            super(name);
129                            this.test = test;
130                            this.result = result;
131                    }
132                    
133                    public void run()
134                    {
135                            test.run(result);
136                    }
137            }
138    }